#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import itertools
import os

# Kernel name prefixes to instantiate, keyed by the ``--type`` option. Each
# entry is ``(kernel name prefix, needs_workspace)``; when the flag is true
# the generated instantiation gets an extra ``int* d_workspace`` argument.
# NOTE(review): there is no "imma" entry, so ``--type imma`` has no prefixes
# to generate from.
PREFIXES = {
    "dp4a": [
        ("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True),
        ("batch_conv_bias_int8_gemm_ncdiv4hw4", False),
        ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False),
    ],
}

# ``(NonlineMode enumerator name, file-name tag)`` per activation id.
ACTIVATIONS = {
    1: ("IDENTITY", "_id"),
    2: ("RELU", "_relu"),
    3: ("H_SWISH", "_hswish"),
}

# ``(bias visitor type name, file-name tag)`` per bias id.
BIASES = {
    1: ("PerElementBiasVisitor", "_per_elem"),
    2: ("PerChannelBiasVisitor", "_per_chan"),
}

# Extra kernel-name suffixes per ``--type``; currently only the empty suffix.
# The comprehension gives each type its own independent list.
SUFFIXES = {kind: [""] for kind in ("dp4a", "imma")}


# Explicit-instantiation template written into every generated file. The
# upper-case placeholders (PREFIX, SUFFIX, BIAS, ACTIVATION, WORKSPACE) are
# substituted by ``render_instantiation``; all other text, including trailing
# spaces, is emitted verbatim.
_INST_TEMPLATE = """
template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, 
        IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>(
        const int8_t* d_src, 
        const int8_t* d_filter, WORKSPACE 
        BIAS bias, 
        IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>> epilogue, 
        const ConvParam& param, 
        float alpha, 
        float beta, 
        cudaStream_t stream);"""


def render_instantiation(prefix, suffix, bias, activation, has_workspace):
    """Return the explicit template instantiation text for one kernel variant.

    Args:
        prefix: kernel name prefix (the ``do_`` function name stem).
        suffix: extra kernel-name suffix appended to ``prefix`` (may be "").
        bias: bias visitor type name substituted for ``BIAS``.
        activation: ``NonlineMode`` enumerator name substituted for
            ``ACTIVATION``.
        has_workspace: when true, an ``int* d_workspace`` parameter is
            inserted after ``d_filter``; otherwise the placeholder is dropped.

    Returns:
        The instantiation text; it starts with a newline and ends with
        ``cudaStream_t stream);``.
    """
    text = (
        _INST_TEMPLATE.replace("PREFIX", prefix)
        .replace("SUFFIX", suffix)
        .replace("BIAS", bias)
        .replace("ACTIVATION", activation)
    )
    workspace = "\nint* d_workspace, " if has_workspace else ""
    return text.replace("WORKSPACE", workspace)


def main(argv=None):
    """Generate one ``.cu`` instantiation file per (prefix, suffix, activation).

    Args:
        argv: argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(
        description="generate cuda batch conv bias (dp4a/imma) kern impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["dp4a", "imma"],
        default="dp4a",
        help="generate cuda conv bias kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args(argv)

    # The CLI accepts "imma", but PREFIXES may not define it; report a usage
    # error instead of crashing with a KeyError further down.
    if args.type not in PREFIXES:
        parser.error("no kernel prefixes are defined for --type {}".format(args.type))

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.output, exist_ok=True)

    # Only BIASES[2] (the per-channel visitor) is instantiated.
    bias_type, bias_tag = BIASES[2]
    variants = itertools.product(
        PREFIXES[args.type], SUFFIXES[args.type], ACTIVATIONS.values()
    )
    for (prefix, has_workspace), suffix, (act_mode, act_tag) in variants:
        fname = os.path.join(
            args.output, "{}{}{}{}.cu".format(prefix, suffix, bias_tag, act_tag)
        )
        contents = "\n".join(
            [
                "// generated by gen_batch_cuda_conv_bias_kern_impls.py",
                '#include "../{}{}.cuinl"'.format(prefix, suffix),
                render_instantiation(
                    prefix, suffix, bias_type, act_mode, has_workspace
                ),
            ]
        )
        with open(fname, "w") as fout:
            fout.write(contents + "\n")
        print("generated {}".format(fname))
    # Touch the output directory so its mtime reflects this regeneration.
    os.utime(args.output)


if __name__ == "__main__":
    # Run the generator when executed as a script; main() returns None, so
    # the process exits with status 0 on success.
    raise SystemExit(main())
